// SPDX-License-Identifier: MIT
// SPDX-FileCopyrightText: 2023 werbenhu
// SPDX-FileContributor: werbenhu

package digo

import (
	"fmt"
	"go/ast"
	"go/format"
	"go/token"
	"os"
	"path/filepath"
	"sort"
	"strings"
)

// newIdent returns a fresh identifier node named name. The node carries no
// source position and no resolved object.
func newIdent(name string) *ast.Ident {
	return ast.NewIdent(name)
}

// newSelectorExpr builds a selector expression from a qualified name of the
// form "pkg.Name"; for example "digo.Provide" yields X=digo, Sel=Provide.
//
// It panics if selector is not exactly two non-empty parts separated by a
// single '.', because such a string cannot be expressed as one
// SelectorExpr. Previously such input failed with an opaque index-out-of-range
// panic (no dot) or was silently truncated ("a.b.c" became a.b).
func newSelectorExpr(selector string) *ast.SelectorExpr {
	pkg, name, ok := strings.Cut(selector, ".")
	if !ok || pkg == "" || name == "" || strings.Contains(name, ".") {
		panic(fmt.Sprintf("digo: selector %q is not of the form pkg.Name", selector))
	}
	return &ast.SelectorExpr{
		X:   ast.NewIdent(pkg),
		Sel: ast.NewIdent(name),
	}
}

// newStarExpr creates a new ast.StarExpr with the given package and name.
func newStarExpr(selector string) *ast.StarExpr {
	splitted := strings.Split(selector, ".")
	return &ast.StarExpr{
		X: &ast.SelectorExpr{
			X:   newIdent(splitted[0]),
			Sel: newIdent(splitted[1]),
		},
	}
}

// newCommentGroup returns a comment group with one ast.Comment per entry of
// texts, in order. Each text is used verbatim as Comment.Text, so it should
// include its own comment marker (e.g. "// ..."). A nil or empty texts yields
// a group with an empty, non-nil List.
func newCommentGroup(texts []string) *ast.CommentGroup {
	group := &ast.CommentGroup{List: make([]*ast.Comment, 0, len(texts))}
	for _, text := range texts {
		group.List = append(group.List, &ast.Comment{Text: text})
	}
	return group
}

// newCallExpr returns the call expression fn(args...). The args slice is
// stored as-is (not copied), so it shares its backing array with the caller.
func newCallExpr(fn ast.Expr, args []ast.Expr) *ast.CallExpr {
	call := new(ast.CallExpr)
	call.Fun = fn
	call.Args = args
	return call
}

// newExprs collects its arguments into a freshly allocated []ast.Expr. The
// result never aliases the caller's slice (even for newExprs(s...)), and it is
// empty but non-nil when called with no arguments.
func newExprs(exprs ...ast.Expr) []ast.Expr {
	return append(make([]ast.Expr, 0, len(exprs)), exprs...)
}

// newBasicLit returns a string-literal node for val. The value is quoted and
// escaped with Go syntax (the %q verb), so quotes, backslashes or control
// characters in val cannot produce invalid generated source. Plain values
// such as IDs and import paths are rendered exactly as "val".
func newBasicLit(val string) *ast.BasicLit {
	return &ast.BasicLit{
		Kind:  token.STRING,
		Value: fmt.Sprintf("%q", val),
	}
}

// newErrCheckStmt returns the statement
//
//	if err != nil {
//		panic(err)
//	}
//
// which the generated initializers place right after each call that yields err.
func newErrCheckStmt() ast.Stmt {
	cond := &ast.BinaryExpr{
		X:  ast.NewIdent("err"),
		Op: token.NEQ,
		Y:  ast.NewIdent("nil"),
	}
	panicCall := &ast.CallExpr{
		Fun:  ast.NewIdent("panic"),
		Args: []ast.Expr{ast.NewIdent("err")},
	}
	return &ast.IfStmt{
		Cond: cond,
		Body: &ast.BlockStmt{List: []ast.Stmt{&ast.ExprStmt{X: panicCall}}},
	}
}

// newImportSpec returns an import spec for path. A non-empty alias becomes
// the import name (alias "path"); an empty alias yields a plain import.
func newImportSpec(path, alias string) *ast.ImportSpec {
	var name *ast.Ident
	if alias != "" {
		name = &ast.Ident{Name: alias}
	}
	return &ast.ImportSpec{Name: name, Path: newBasicLit(path)}
}

// objName derives the name of the local variable that holds the untyped
// object fetched for an injected parameter: every '.' and '/' in prefix is
// replaced with '_' and "_obj" is appended (e.g. "a.b/c" -> "a_b_c_obj").
func objName(prefix string) string {
	return strings.NewReplacer(".", "_", "/", "_").Replace(prefix) + "_obj"
}

// Generator is a code generator for dependency injection. It turns the
// providers and group members found in Package into a Go source file named
// GeneratedFileName inside Package.Folder, whose init() registers them with
// the DI object manager.
type Generator struct {
	Package         *DiPackage          // Package whose Funcs drive the generated code
	CalledInitFuncs []ast.Stmt          // Calls to each generated initializer; used as the body of init()
	Fset            *token.FileSet      // FileSet passed to ast.SortImports and format.Node
	Decls           []ast.Decl          // Top-level declarations of the generated file (initializers, init, imports)
	ImportSpecs     map[string]ast.Spec // Packages to be imported, keyed by path + "_" + alias

	ManagerPackage    string // Import path of the DI manager package; always imported by Do
	RegisterFunction  string // "pkg.Func" called to register a provider's singleton
	ProvideFunction   string // "pkg.Func" called to fetch a dependency by provider ID
	GroupFunction     string // "pkg.Func" called to add a member to a group
	GeneratedFileName string // Name of the file written into Package.Folder
}

// NewGenerator returns a Generator for pkg with empty output state and the
// default settings. It imports github.com/werbenhu/digo, emits calls to
// digo.RegisterSingleton, digo.Provide and digo.RegisterMember, and writes
// its output to digo.generated.go.
func NewGenerator(pkg *DiPackage) *Generator {
	g := &Generator{
		Package: pkg,
		Fset:    token.NewFileSet(),
	}

	// Output accumulators start empty (non-nil).
	g.CalledInitFuncs = []ast.Stmt{}
	g.Decls = []ast.Decl{}
	g.ImportSpecs = map[string]ast.Spec{}

	// Defaults describing the DI manager API used by the generated code.
	g.ManagerPackage = "github.com/werbenhu/digo"
	g.RegisterFunction = "digo.RegisterSingleton"
	g.ProvideFunction = "digo.Provide"
	g.GroupFunction = "digo.RegisterMember"
	g.GeneratedFileName = "digo.generated.go"
	return g
}

// defineInjectStmts emits the statements that resolve one injected
// dependency inside a generated initializer:
//
//	<objName(Param)>, err := <ProvideFunction>("<ProviderId>")
//	if err != nil {
//		panic(err)
//	}
//	<Param> := <objName(Param)>.(<Typ>)
//
// If the injector names a package, that package (with its alias) is also
// recorded for import.
func (g *Generator) defineInjectStmts(inject *Injector) []ast.Stmt {
	if inject.Pkg != "" {
		g.addImport(inject.Pkg, inject.Alias)
	}

	holder := objName(inject.Param)

	fetch := &ast.AssignStmt{
		Lhs: newExprs(newIdent(holder), newIdent("err")),
		Tok: token.DEFINE,
		Rhs: newExprs(newCallExpr(
			newSelectorExpr(g.ProvideFunction),
			[]ast.Expr{newBasicLit(inject.ProviderId)},
		)),
	}
	cast := &ast.AssignStmt{
		Lhs: newExprs(newIdent(inject.Param)),
		Tok: token.DEFINE,
		Rhs: newExprs(&ast.TypeAssertExpr{
			X:    newIdent(holder),
			Type: inject.Typ,
		}),
	}

	return []ast.Stmt{fetch, newErrCheckStmt(), cast}
}

// defineProviderFunc builds the declaration of the singleton initializer for
// provider fn. The generated function takes no parameters, and its body:
//
//  1. resolves each of fn.Injectors (see defineInjectStmts),
//  2. calls fn.Name with the resolved values, in injector order, and
//  3. registers the result under fn.ProviderId via g.RegisterFunction.
func (g *Generator) defineProviderFunc(fn *DiFunc) *ast.FuncDecl {
	var (
		body     = make([]ast.Stmt, 0)
		callArgs = make([]ast.Expr, 0)
	)
	for _, inject := range fn.Injectors {
		callArgs = append(callArgs, newIdent(inject.GetArgName()))
		body = append(body, g.defineInjectStmts(inject)...)
	}

	// <obj> := <fn.Name>(<callArgs>...)
	construct := &ast.AssignStmt{
		Lhs: newExprs(newIdent(fn.providerObjName())),
		Tok: token.DEFINE,
		Rhs: newExprs(newCallExpr(newIdent(fn.Name), callArgs)),
	}
	// <RegisterFunction>("<ProviderId>", <obj>)
	register := &ast.ExprStmt{
		X: newCallExpr(
			newSelectorExpr(g.RegisterFunction),
			newExprs(newBasicLit(fn.ProviderId), newIdent(fn.providerObjName())),
		),
	}
	body = append(body, construct, register)

	// The leading "\n" in the first line leaves a blank line before the doc comment.
	doc := newCommentGroup([]string{
		fmt.Sprintf("\n// %s registers the singleton object with ID %s into the DI object manager", fn.providerFuncName(), fn.ProviderId),
		fmt.Sprintf("// Now you can retrieve the singleton object by using `obj, err := digo.Provide(\"%s\")`.", fn.ProviderId),
		"// The obj obtained from the above code is of type `any`.",
		"// You will need to forcefully cast the obj to its corresponding actual object type.",
	})

	return &ast.FuncDecl{
		Doc:  doc,
		Name: newIdent(fn.providerFuncName()),
		Type: &ast.FuncType{},
		Body: &ast.BlockStmt{List: body},
	}
}

// defineProviderFuncs generates one singleton initializer for every function
// in g.Package that declares a ProviderId; other functions are skipped. Each
// initializer (named by fn.providerFuncName) is appended to g.Decls, and a
// call to it is queued in g.CalledInitFuncs so the generated init() runs it.
func (g *Generator) defineProviderFuncs() {
	for _, fn := range g.Package.Funcs {
		if len(fn.ProviderId) == 0 {
			continue
		}

		decl := g.defineProviderFunc(fn)
		g.Decls = append(g.Decls, decl)

		call := &ast.ExprStmt{
			X: newCallExpr(newIdent(fn.providerFuncName()), newExprs()),
		}
		g.CalledInitFuncs = append(g.CalledInitFuncs, call)
	}
}

// defineGroupFuncs generates one member initializer for every function in
// g.Package that declares a GroupId; other functions are skipped. Each
// initializer (named by fn.groupFuncName) is appended to g.Decls, and a call
// to it is queued in g.CalledInitFuncs so the generated init() runs it.
func (g *Generator) defineGroupFuncs() {
	for _, fn := range g.Package.Funcs {
		if len(fn.GroupId) == 0 {
			continue
		}

		decl := g.defineGroupFunc(fn)
		g.Decls = append(g.Decls, decl)

		call := &ast.ExprStmt{
			X: newCallExpr(newIdent(fn.groupFuncName()), newExprs()),
		}
		g.CalledInitFuncs = append(g.CalledInitFuncs, call)
	}
}

// addImport records pkg in g.ImportSpecs, imported under alias (or under its
// own name when alias is empty). Repeated requests for the same key are
// ignored, so each pkg/alias pair is imported at most once.
//
// NOTE(review): the key is pkg + "_" + alias, so distinct pairs can collide
// (e.g. "x/a" with alias "b_" and "x/a_b" with no alias both map to "x/a_b_").
// In that case the second import is silently dropped. The key format is kept
// as-is because ImportSpecs is an exported field.
func (g *Generator) addImport(pkg string, alias string) {
	key := pkg + "_" + alias
	if _, exists := g.ImportSpecs[key]; exists {
		return
	}
	g.ImportSpecs[key] = newImportSpec(pkg, alias)
}

// defineGroupFunc builds the declaration of the member initializer for fn.
// The generated function takes no parameters. It obtains a member object and
// adds it to group fn.GroupId by calling g.GroupFunction(groupID, member).
//
// The member is obtained in one of two ways:
//   - fn also has a ProviderId: the singleton registered under that ID is
//     fetched with g.ProvideFunction. Do generates provider initializers
//     before group initializers, so init() registers the provider first.
//   - otherwise: each injected dependency is resolved and fn.Name is called
//     with them to construct the member.
func (g *Generator) defineGroupFunc(fn *DiFunc) *ast.FuncDecl {
	stmts := make([]ast.Stmt, 0)

	if len(fn.ProviderId) > 0 {
		// member, err := <ProvideFunction>("<ProviderId>")
		// GroupFunction registers members and is called with (groupID, member)
		// below, so the singleton must be fetched with ProvideFunction instead.
		stmts = append(stmts,
			&ast.AssignStmt{
				Lhs: newExprs(newIdent("member"), newIdent("err")),
				Tok: token.DEFINE,
				Rhs: newExprs(
					newCallExpr(
						newSelectorExpr(g.ProvideFunction),
						[]ast.Expr{newBasicLit(fn.ProviderId)},
					),
				),
			},
			newErrCheckStmt(),
		)
	} else {
		// Resolve each dependency, then: member := <fn.Name>(<args>...)
		args := make([]ast.Expr, 0, len(fn.Injectors))
		for _, inject := range fn.Injectors {
			// NOTE(review): defineProviderFunc passes inject.GetArgName() at this
			// point; confirm that Param is the intended argument for group members.
			args = append(args, newIdent(inject.Param))
			stmts = append(stmts, g.defineInjectStmts(inject)...)
		}
		stmts = append(stmts, &ast.AssignStmt{
			Lhs: newExprs(newIdent("member")),
			Tok: token.DEFINE,
			Rhs: newExprs(newCallExpr(newIdent(fn.Name), args)),
		})
	}

	// <GroupFunction>("<GroupId>", member)
	stmts = append(stmts, &ast.ExprStmt{
		X: newCallExpr(newSelectorExpr(g.GroupFunction), newExprs(
			newBasicLit(fn.GroupId),
			newIdent("member")),
		),
	})

	// The leading "\n" in the first line leaves a blank line before the doc comment.
	comments := []string{
		fmt.Sprintf("\n// Add a member object to group: %s", fn.GroupId),
		fmt.Sprintf("// Now you can retrieve the group's member objects by using `objs, err := digo.Members(\"%s\")`.", fn.GroupId),
		"// The objs obtained from the above code are of type `[]any`.",
		"// You will need to forcefully cast the objs to their corresponding actual object types.",
	}

	return &ast.FuncDecl{
		Doc:  newCommentGroup(comments),
		Name: newIdent(fn.groupFuncName()),
		Type: &ast.FuncType{},
		Body: &ast.BlockStmt{List: stmts},
	}
}

// defineInitFunc appends the generated package's init() declaration to
// g.Decls. Its body is g.CalledInitFuncs: one call per generated initializer,
// in the order they were collected.
func (g *Generator) defineInitFunc() {
	// The leading "\n" leaves a blank line before the doc comment.
	doc := newCommentGroup([]string{
		"\n// init registers all providers in the current package into the DI object manager.",
	})
	body := &ast.BlockStmt{List: g.CalledInitFuncs}

	g.Decls = append(g.Decls, &ast.FuncDecl{
		Doc:  doc,
		Name: newIdent("init"),
		Type: &ast.FuncType{},
		Body: body,
	})
}

// genAllAstDecls prepends a single import declaration holding every entry of
// g.ImportSpecs to g.Decls, so imports come before the generated functions.
//
// Specs are emitted in sorted key order so that the generated file is stable
// across runs. Map iteration order is random. ast.SortImports also skips
// import declarations without a valid Lparen position, which is the case for
// this synthesized GenDecl, so it cannot restore a stable order.
func (g *Generator) genAllAstDecls() {
	keys := make([]string, 0, len(g.ImportSpecs))
	for key := range g.ImportSpecs {
		keys = append(keys, key)
	}
	sort.Strings(keys)

	importSpecs := make([]ast.Spec, 0, len(keys))
	for _, key := range keys {
		importSpecs = append(importSpecs, g.ImportSpecs[key])
	}

	g.Decls = append([]ast.Decl{&ast.GenDecl{
		Tok:   token.IMPORT,
		Specs: importSpecs,
	}}, g.Decls...)
}

// writeHeaderComment writes the "generated by digogen" banner to file. It
// returns the banner's length in bytes, which output uses as the FileStart
// position of the generated ast.File.
//
// The banner is written verbatim with WriteString rather than passed to
// fmt.Fprintf as a format string. Any '%' in it would otherwise be
// misinterpreted, and go vet flags non-constant format strings. The write
// error is not reported because the method has no error result. A broken
// file presumably surfaces later as an error from format.Node in output.
func (g *Generator) writeHeaderComment(file *os.File) int {
	header := "\n//\n// This file is generated by digogen. Run 'digogen' to regenerate.\n//\n" +
		"// You can install this tool by running `go install github.com/werbenhu/digo/digogen`.\n" +
		"// For more details, please refer to https://github.com/werbenhu/digo. \n//\n"
	_, _ = file.WriteString(header)
	return len(header)
}

// output writes the generated Go file. It:
//   - creates Package.Folder if needed;
//   - replaces any previous GeneratedFileName there;
//   - writes the header banner;
//   - prints the import declaration plus g.Decls as formatted Go source.
//
// It returns the first error from creating the directory, opening the file,
// formatting the AST, or closing the file. The file is always closed.
func (g *Generator) output() (err error) {
	if err = os.MkdirAll(g.Package.Folder, 0777); err != nil {
		return err
	}

	path := filepath.Join(g.Package.Folder, g.GeneratedFileName)
	// Best effort: a missing file is fine. O_TRUNC below ensures no stale
	// bytes survive even if the removal fails for another reason.
	_ = os.Remove(path)

	file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0666)
	if err != nil {
		return err
	}
	defer func() {
		// Close may report a failed write; keep it unless an earlier error won.
		if cerr := file.Close(); cerr != nil && err == nil {
			err = cerr
		}
	}()

	g.genAllAstDecls()
	startPos := g.writeHeaderComment(file)
	dest := &ast.File{
		FileStart: token.Pos(startPos),
		Name: &ast.Ident{
			Name: g.Package.Name,
		},
		Decls: g.Decls,
	}

	ast.SortImports(g.Fset, dest)
	if err = format.Node(file, g.Fset, dest); err != nil {
		return fmt.Errorf("formatting %s: %w", path, err)
	}
	return nil
}

// Do converts the providers and group members extracted from g.Package into
// Go AST declarations and writes them to GeneratedFileName inside
// Package.Folder. It returns any error encountered while writing that file.
//
// Provider initializers are generated before group initializers, so the
// generated init() registers singletons before group members that fetch them.
func (g *Generator) Do() error {
	g.addImport(g.ManagerPackage, "")
	g.defineProviderFuncs()
	g.defineGroupFuncs()
	g.defineInitFunc()
	return g.output()
}
